Code
include("../utils.jl")
import MLJ:predict,predict_mode
import BetaML
using DataFrames,MLJ,CSV,MLJModelInterface,GLMakie
using CatBoost.MLJCatBoostInterface
使用 BetaML.jl 库，在 german-creditcard 数据集上进行实验。
# NOTE(review): stray cell — coerces the scientific types of `d` (few-valued
# columns to finite/categorical, discrete numerics to continuous), but `d` is
# not defined anywhere in this file; presumably a leftover from an earlier
# data-loading cell — verify before running.
coerce(d,autotype(d, (:few_to_finite, :discrete_to_continuous)))
BetaML 是 Julia 中另一个大型的机器学习库，参考文档：BetaML Doc。
include("../utils.jl")
import MLJ:predict,predict_mode
import BetaML
using DataFrames,MLJ,CSV,MLJModelInterface,GLMakie
using CatBoost.MLJCatBoostInterface
# Load the german-creditcard train/test split. Fixed from the scrambled source,
# where the assignment targets were displaced to after the call.
# NOTE(review): assumes `cat` holds categorical-feature info returned by the
# loader — confirm against `load_german_creditcard` in ../utils.jl.
Xtrain, Xtest, ytrain, ytest, cat = load_german_creditcard();
"""
    define_models() -> (models, models_name)

Build the list of classifiers compared on the german-creditcard data —
six BetaML models (neural network, decision tree, kernel perceptron,
linear perceptron, Pegasos, random forest) plus a CatBoost classifier —
together with their short display names, in matching order.

Reconstructed from scrambled source in which every assignment's left-hand
side had been displaced to the following line.
"""
function define_models()
    # Neural network: 19 input features -> 8 -> 8 -> 2 classes, softmax output.
    modelType1 = @load NeuralNetworkClassifier pkg = "BetaML"
    layers = [BetaML.DenseLayer(19, 8, f=BetaML.relu),
              BetaML.DenseLayer(8, 8, f=BetaML.relu),
              BetaML.DenseLayer(8, 2, f=BetaML.relu),
              BetaML.VectorFunctionLayer(2, f=BetaML.softmax)]
    nn_model = modelType1(layers=layers, opt_alg=BetaML.ADAM())

    modelType2 = @load DecisionTreeClassifier pkg = "BetaML" verbosity = 0
    dt_model = modelType2()

    modelType3 = @load KernelPerceptron pkg = "BetaML"
    kp_model = modelType3()

    modelType4 = @load LinearPerceptron pkg = "BetaML"
    lp_model = modelType4()

    modelType5 = @load Pegasos pkg = "BetaML" verbosity = 0
    peg_model = modelType5()

    modelType6 = @load RandomForestClassifier pkg = "BetaML" verbosity = 0
    rf_model = modelType6()

    # Gradient-boosted trees; only 5 iterations, presumably for speed — confirm.
    cat_model = CatBoostClassifier(iterations=5)

    models = [nn_model, dt_model, kp_model, lp_model, peg_model, rf_model, cat_model]
    models_name = ["nn", "dt", "kp", "lp", "peg", "rf", "cat"]
    return models, models_name
end
# Instantiate all candidate models and their display names. Fixed from the
# scrambled source, where the assignment targets trailed the call.
models, models_name = define_models()
[ Info: For silent loading, specify `verbosity=0`.
[ Info: For silent loading, specify `verbosity=0`.
[ Info: For silent loading, specify `verbosity=0`.
import BetaML ✔
import BetaML ✔
import BetaML ✔
(Probabilistic[NeuralNetworkClassifier(layers = BetaML.Nn.AbstractLayer[BetaML.Nn.DenseLayer([-0.36074636511061065 -0.39582846276007433 -0.19367369489258768 0.2680078206831555 0.23415797915589948 -0.35426717723464507 0.41873972936820797 0.03695708843884965 -0.1262241154740001 -0.4535341279511609 0.27083756550881694 -0.26285190070170183 -0.29113824789040094 0.3272732026253102 0.3120181348536681 0.03189504313141717 0.17859507646360423 0.24100348772661923 0.014525967665540818; 0.02585857761506677 -0.1786289777618676 0.27390992608661585 -0.00882009262935829 -0.05597939255352807 -0.45185469645137644 0.43448083345893546 -0.07933636613135336 -0.40808183079083227 -0.1512087063152568 -0.22346202440402502 0.24282973121527757 -0.4611082088959598 -0.22314239815908504 -0.2040062087828935 -0.23716874176485098 -0.4007849436405437 -0.029299013620295244 -0.2240163921875015; 0.3586660049153875 0.0990446439566613 -0.30304464083062305 -0.35711866133831294 -0.18581141830140957 0.10300116526719799 -0.3547092905827105 0.17358444569480663 0.46816569231449073 0.36725354636566204 -0.09950624589607898 -0.18348423586113166 -0.05591644911594651 0.21540326133872428 0.4397024519627983 -0.019152022051641848 0.015081037200303848 -0.09338789491179 0.2698572766549305; -0.14737750284903278 0.15390901525950657 0.11733088463889524 -0.18553420478116528 0.2773342867980156 -0.44187683747564216 -0.07128222618035901 -0.3975783832585831 -0.1797548030782582 0.45414987217334984 0.11774352278251005 0.4318637917472961 0.34114366748895236 0.20543094312431615 0.45879918119036506 0.19966031284458824 -0.4243977564030008 0.3528498049246546 -0.26449251745539387; -0.46603077560973166 0.363177679644008 0.15387610207235164 0.16878070443105436 0.22469025601685005 0.43163009427259763 0.3433671636225248 -0.04331120409999056 -0.3450286254110857 0.27013134708859105 -0.1246971621378602 0.18019183676577127 -0.22935636935801848 -0.38713683379338093 -0.36324994527472265 0.12175941762041481 0.37182252248090214 -0.2946228229043437 
0.36728385779700773; -0.35965444222821236 -0.2513459765142385 -0.12760789988466997 -0.3651832232912821 0.2275393345538726 0.19344000211089812 -0.06541914124742532 0.45592078923421525 -0.3588317610866417 0.37053197119661924 -0.16897336019104503 0.006708286081282766 0.22294499850839905 0.3064633029426777 0.449684771230571 0.29878199021450175 -0.19723891228337898 0.03398494223628351 0.195930874969185; 0.16768892675786945 0.20178877728541117 -0.01920955698773008 -0.07346464258077734 -0.22641700608346715 -0.3972909747656714 -0.3926505747242142 -0.3211783657517294 -0.42299054770886885 0.4394603691078505 -0.020340987431912816 0.31152821754851673 0.275435690977708 -0.453777188734648 0.38653852340376965 0.2257564654780106 -0.39273527136599073 -0.3588904629697215 0.04673991700334451; -0.3922339208951156 0.4104731956148045 0.01849632586826505 -0.2311936494942672 0.05614872575245905 -0.10645646149582344 -0.4578158813458524 0.17582732885781888 0.007762622179870338 0.1398437234045315 -0.2026471964256374 0.3674505083332545 -0.26909992476264555 -0.20785693976235886 0.07734529876859592 0.08917618938717092 -0.22274107976966068 -0.3823399200961872 -0.45505603896994484], [-0.157465439622054, -0.21995520319116818, -0.3384789069755978, -0.15006788522938158, -0.08068739583553858, 0.17488394613124297, -0.04748199633913597, -0.3250052328690887], BetaML.Utils.relu, BetaML.Utils.drelu), BetaML.Nn.DenseLayer([-0.3852959591608028 -0.553059892716297 -0.015028124586988656 -0.011201024404014936 0.005100673463055649 -0.18746389838773353 -0.4465071177773833 -0.24705649867006285; 0.6026146990300689 -0.33870198929549894 0.2718518742522277 -0.3384705743525871 -0.06814271449805787 0.606002135337379 -0.10336455886078566 0.26238572386113823; 0.45018281052830456 0.18679055621241314 -0.37362480422350963 0.006078042783034587 -0.45064052026604706 -0.1425599823506864 -0.3585911486603981 -0.6063601118282695; -0.05738177001216915 -0.3564277181599287 -0.36781688568508414 -0.14252074674595333 0.2571075127666419 
0.19183099681926707 -0.018662422915370347 -0.3510293358769623; -0.08061251239900558 0.1179048173797197 -0.25520661838834113 -0.02273079835591796 0.38222026437329126 0.40094107662179646 -0.5807559333068457 0.36662125095206255; 0.33297001370651136 0.33740265915614154 0.11892808620953843 -0.461833943645412 -0.44731582380216783 -0.1810395086037453 0.2447640455667498 0.18200413571170693; -0.48967054357747297 -0.4865518207276439 -0.24790297957249452 -0.45395785625464546 -0.21218635943332892 -0.11260342698590892 -0.38292860160231035 -0.5607840330610061; 0.49172091443161037 0.348025736136519 0.04600440542299833 -0.5245141418067015 -0.16209562928447985 -0.16401386904196824 -0.08199421251258743 0.043141435301436104], [0.5320229402035459, 0.08848618862963409, 0.49077131013907294, 0.40163369167512397, -0.3241193921715959, -0.4400009135782846, 0.07456092888227783, 0.36052644514730536], BetaML.Utils.relu, BetaML.Utils.drelu), BetaML.Nn.DenseLayer([-0.3367271772548713 0.6877326959270641 0.1575891073545942 -0.4441500273567661 0.20426279244526513 -0.08785955388176736 0.09549279041837222 -0.25434472335021696; -0.21366575215568595 0.01925661356946584 0.654401524198115 0.07698019071966455 0.006405545613967667 -0.4771138208850888 -0.7483462850040817 0.09549653596307062], [-0.604516114936473, -0.11748999487867229], BetaML.Utils.relu, BetaML.Utils.drelu), BetaML.Nn.VectorFunctionLayer{0}(fill(NaN), 2, 2, BetaML.Utils.softmax, BetaML.Utils.dsoftmax, nothing)], …), DecisionTreeClassifier(max_depth = 0, …), KernelPerceptron(kernel = radial_kernel, …), LinearPerceptron(initial_coefficients = nothing, …), Pegasos(initial_coefficients = nothing, …), RandomForestClassifier(n_trees = 30, …), CatBoostClassifier(iterations = 5, …)], ["nn", "dt", "kp", "lp", "peg", "rf", "cat"])
"""
    train_model()

Fit each of the first six models in the global `models` on
`(Xtrain, ytrain)`, predict the mode class on `Xtest`, and log the
accuracy (rounded to 3 digits) under the matching entry of
`models_name`. The seventh model (CatBoost) is skipped by the `1:6`
slice, as in the original run.
"""
function train_model()
    for (idx, mdl) in enumerate(models[1:6])
        # MLJ's low-level fit: verbosity 0, returns (fitresult, cache, report).
        local (fitted, cache_, report_) = MLJ.fit(mdl, 0, Xtrain, ytrain)
        local predicted = predict_mode(mdl, fitted, Xtest)
        local acc = round(accuracy(ytest, predicted); digits=3)
        @info "$(models_name[idx])===>$(acc)"
    end
end
# Run the comparison; per-model accuracies are emitted via @info (see log below).
train_model()
[ Info: nn===>0.305
[ Info: dt===>0.705
[ Info: kp===>0.335
[ Info: lp===>0.45
[ Info: peg===>0.695
[ Info: rf===>0.735